#
# SPDX-FileCopyrightText: Copyright (c) 2021-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""Layer for implementing the channel in the time domain"""

import tensorflow as tf

from sionna.phy import Block
from . import GenerateTimeChannel, ApplyTimeChannel
from .utils import time_lag_discrete_time_channel

class TimeChannel(Block):
    # pylint: disable=line-too-long
    r"""
    Generates channel responses and applies them to channel inputs in the time domain

    The channel output consists of ``num_time_samples`` + ``l_max`` - ``l_min``
    time samples, as it is the result of filtering the channel input of length
    ``num_time_samples`` with the time-variant channel filter of length
    ``l_max`` - ``l_min`` + 1. In the case of a single-input single-output link and given a sequence of channel
    inputs :math:`x_0,\cdots,x_{N_B-1}`, where :math:`N_B` is ``num_time_samples``, this
    layer outputs

    .. math::
        y_b = \sum_{\ell = L_{\text{min}}}^{L_{\text{max}}} x_{b-\ell} \bar{h}_{b,\ell} + w_b

    where :math:`L_{\text{min}}` corresponds to ``l_min``, :math:`L_{\text{max}}` to ``l_max``, :math:`w_b` to
    the additive noise, and :math:`\bar{h}_{b,\ell}` to the
    :math:`\ell^{th}` tap of the :math:`b^{th}` channel sample.
    This layer outputs :math:`y_b` for :math:`b` ranging from :math:`L_{\text{min}}` to
    :math:`N_B + L_{\text{max}} - 1`, and :math:`x_{b}` is set to 0 for :math:`b < 0` or :math:`b \geq N_B`.
    The channel taps :math:`\bar{h}_{b,\ell}` are computed assuming a sinc filter
    is used for pulse shaping and receive filtering. Therefore, given a channel impulse response
    :math:`(a_{m}(t), \tau_{m}), 0 \leq m \leq M-1`, generated by the ``channel_model``,
    the channel taps are computed as follows:

    .. math::
        \bar{h}_{b, \ell}
        = \sum_{m=0}^{M-1} a_{m}\left(\frac{b}{W}\right)
            \text{sinc}\left( \ell - W\tau_{m} \right)

    for :math:`\ell` ranging from ``l_min`` to ``l_max``, and where :math:`W` is
    the ``bandwidth``.

    For multiple-input multiple-output (MIMO) links, the channel output is computed for each antenna of each receiver and by summing over all the antennas of all transmitters.

    Parameters
    ----------
    channel_model : :class:`~sionna.phy.channel.ChannelModel`
        Used channel model

    bandwidth : `float`
        Bandwidth (:math:`W`) [Hz]

    num_time_samples : `int`
        Number of time samples forming the channel input (:math:`N_B`)

    maximum_delay_spread : `float`, (default 3e-6)
        Maximum delay spread [s].
        Used to compute the default value of ``l_max`` if ``l_max`` is set to
        `None`. If a value is given for ``l_max``, this parameter is not used.
        It defaults to 3 microseconds, which was found
        to be large enough to include most significant paths with all channel
        models included in Sionna assuming a nominal delay spread of 100ns.

    l_min : `None` (default) | `int`
        Smallest time-lag for the discrete complex baseband channel (:math:`L_{\text{min}}`).
        If set to `None`, defaults to the value given by :func:`time_lag_discrete_time_channel`.

    l_max : `None` (default) | `int`
        Largest time-lag for the discrete complex baseband channel (:math:`L_{\text{max}}`).
        If set to `None`, it is computed from ``bandwidth`` and ``maximum_delay_spread``
        using :func:`time_lag_discrete_time_channel`. If it is not set to `None`,
        then the parameter ``maximum_delay_spread`` is not used.
        After defaults are applied, ``l_max`` must not be smaller than ``l_min``.

    normalize_channel : `bool`, (default `False`)
        If set to `True`, the channel is normalized over the block size
        to ensure unit average energy per time step.

    return_channel : `bool`, (default `False`)
        If set to `True`, the channel response is returned in addition to the
        channel output.

    precision : `None` (default) | "single" | "double"
        Precision used for internal calculations and outputs.
        If set to `None`,
        :attr:`~sionna.phy.config.Config.precision` is used.

    Input
    -----
    x : [batch size, num_tx, num_tx_ant, num_time_samples], `tf.complex`
        Channel inputs

    no : `None` (default) | Tensor, `tf.float`
        Tensor whose shape can be broadcast to the shape of the
        channel outputs: [batch size, num_rx, num_rx_ant, num_time_samples + l_max - l_min].
        The (optional) noise power ``no`` is per complex dimension. If ``no`` is a scalar,
        noise of the same variance will be added to the outputs.
        If ``no`` is a tensor, it must have a shape that can be broadcast to
        the shape of the channel outputs. This allows, e.g., adding noise of
        different variance to each example in a batch. If ``no`` has a lower
        rank than the channel outputs, then ``no`` will be broadcast to the
        shape of the channel outputs by adding dummy dimensions after the last
        axis.

    Output
    ------
    y : [batch size, num_rx, num_rx_ant, num_time_samples + l_max - l_min], `tf.complex`
        Channel outputs.
        The channel output consists of ``num_time_samples`` + ``l_max`` - ``l_min``
        time samples, as it is the result of filtering the channel input of length
        ``num_time_samples`` with the time-variant channel filter of length
        ``l_max`` - ``l_min`` + 1.

    h_time : [batch size, num_rx, num_rx_ant, num_tx, num_tx_ant, num_time_samples + l_max - l_min, l_max - l_min + 1], `tf.complex`
        (Optional) Channel responses. Returned only if ``return_channel``
        is set to `True`.
        For each batch example, ``num_time_samples`` + ``l_max`` - ``l_min`` time
        steps of the channel realizations are generated to filter the channel input.

    Raises
    ------
    ValueError
        If, after defaults are applied, ``l_max`` is smaller than ``l_min``.
    """
    def __init__(self, channel_model, bandwidth, num_time_samples,
                 maximum_delay_spread=3e-6, l_min=None, l_max=None,
                 normalize_channel=False, return_channel=False,
                 precision=None, **kwargs):

        super().__init__(precision=precision, **kwargs)

        # Fall back to the default time lags only for the bound(s) the user
        # did not provide. If both are given, maximum_delay_spread is unused,
        # so there is no need to compute the defaults at all.
        if l_min is None or l_max is None:
            l_min_default, l_max_default = time_lag_discrete_time_channel(
                bandwidth, maximum_delay_spread)
            if l_min is None:
                l_min = l_min_default
            if l_max is None:
                l_max = l_max_default

        # The discrete channel filter has l_max - l_min + 1 taps; an inverted
        # range would give a non-positive filter length, so reject it early
        # instead of passing it on to the channel generator/applier.
        if l_max < l_min:
            raise ValueError(
                f"`l_max` ({l_max}) must be greater than or equal to "
                f"`l_min` ({l_min})")

        self._cir_sampler = channel_model
        self._bandwidth = bandwidth
        self._num_time_steps = num_time_samples
        self._l_min = l_min
        self._l_max = l_max
        self._l_tot = l_max-l_min+1
        self._normalize_channel = normalize_channel
        self._return_channel = return_channel

        # Sub-block producing the discrete-time channel taps h_time
        self._generate_channel = GenerateTimeChannel(self._cir_sampler,
                                                     self._bandwidth,
                                                     self._num_time_steps,
                                                     self._l_min,
                                                     self._l_max,
                                                     self._normalize_channel,
                                                     precision=self.precision)

        # Sub-block filtering the input with h_time (and adding noise if given)
        self._apply_channel = ApplyTimeChannel(self._num_time_steps,
                                               self._l_tot,
                                               precision=self.precision)

    def call(self, x, no=None):
        """Generate channel taps for the batch and apply them to ``x``.

        Returns ``y`` or, if ``return_channel`` is `True`, ``(y, h_time)``.
        """
        # tf.shape gives the (possibly dynamic) batch size of the input, so
        # that one set of channel realizations is generated per batch example.
        h_time = self._generate_channel(tf.shape(x)[0])
        y = self._apply_channel(x, h_time, no)
        if self._return_channel:
            return y, h_time
        return y
